#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Outputs PDF figures in each subdirectory
# Checks that the convergence is 2nd order
# Exits with status 1 if any test failed

#requires: all_tests

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log

from os.path import join



# Build the test executable before running anything; the standard output
# of make is redirected to make.log.
# NOTE(review): shell_safe presumably aborts on a non-zero exit status --
# confirm against boututils.run_wrapper.
print("Making MMS diffusion test")
shell_safe("make > make.log")

# Test cases as (input directory, [options that receive the resolution]).
# A combined case is currently disabled:
#   ("XYZ", ["mesh:nx", "mesh:ny", "MZ"])
inputs = [
    ("X", ["mesh:nx"]),
    ("Y", ["mesh:ny"]),
    ("Z", ["MZ"]),
]

# Resolutions to scan, doubling each time: 4, 8, 16, ..., 256.
# Alternative high-resolution scan (disabled):
#   nxlist = [128, 256, 512, 1024, 2048, 4096]
nxlist = [2 ** power for power in range(2, 9)]

# Output settings handed to the executable on its command line
nout = 1
timestep = 0.1

# Number of processes requested from launch_safe
nproc = 2

# Overall result; cleared as soon as one test case fails
success = True

# Run every test case over all resolutions, record the error norms,
# check the convergence order and (best effort) plot the result.
for direc, sizes in inputs:
    print("Running test in '%s'" % (direc))

    error_2 = []    # The L2 error (RMS)
    error_inf = []  # The maximum error

    for nx in nxlist:
        # Every option listed in `sizes` is set to the same resolution
        args = "-d " + direc + " nout=" + str(nout) + " timestep=" + str(timestep)
        for opt in sizes:
            args += " " + opt + "=" + str(nx)

        print("  Running with " + args)

        # Delete old data; -f keeps rm quiet when no dump files exist yet
        shell("rm -f " + direc + "/BOUT.dmp.*.nc")

        # Command to run, launched using MPI
        cmd = "./cyto " + args
        _status, out = launch_safe(cmd, nproc=nproc, pipe=True)

        # Save output to log file.
        # NOTE(review): the log name does not include the test directory,
        # so each test case overwrites the logs of the previous one.
        with open("run.log." + str(nx), "w") as logfile:
            logfile.write(out)

        # Collect the error field at the final output step
        E_N = collect("E_N", tind=[nout, nout], path=direc, info=False)

        # Error norms over the collected data
        # (original note: average over domain, not including guard cells)
        l2 = sqrt(mean(E_N**2))
        linf = max(abs(E_N))

        error_2.append(l2)
        error_inf.append(linf)

        print("  -> Error norm: l-2 %f l-inf %f" % (l2, linf))

    # Calculate grid spacing.
    # This is only correct in the x-direction if MXG = 1. In the other
    # directions dy = 1/ny, dz = 1/(MZ-1).
    dx = 1. / (array(nxlist) - 2.)

    # Calculate convergence order from the two finest resolutions
    order = log(error_2[-1] / error_2[-2]) / log(dx[-1] / dx[-2])
    print("Convergence order = %f" % (order))

    # Anything outside (1.8, 2.2), including NaN, counts as a failure
    if not (1.8 < order < 2.2):
        success = False
        print("=> FAILED\n")

    # Plot errors. Plotting is best effort: a missing or broken matplotlib
    # must not fail the test, but the reason is reported instead of being
    # silently discarded.
    try:
        import matplotlib.pyplot as plt

        fig = plt.figure()
        try:
            plt.plot(dx, error_2, '-o', label=r'$l^2$')
            plt.plot(dx, error_inf, '-x', label=r'$l^\infty$')

            # Reference line with the measured order through the finest point
            plt.plot(dx, error_2[-1]*(dx/dx[-1])**order, '--',
                     label="Order %.1f" % (order))

            plt.legend(loc="upper left")
            plt.grid()

            plt.yscale('log')
            plt.xscale('log')

            plt.xlabel(r'Mesh spacing $\delta x$')
            plt.ylabel("Error norm")

            plt.savefig(join(direc, "norm.pdf"))

            # Uncomment to view the figure interactively:
            # plt.show()
        finally:
            # Always release the figure, even if plotting failed part way
            plt.close(fig)
    except Exception as err:
        print("  Warning: could not plot errors: %s" % (err,))

# Report the overall result and set the process exit status:
# 0 if every test passed, 1 if any test failed.
# sys.exit is used instead of the exit() helper, which is injected by the
# site module for interactive sessions and is missing under `python -S`.
import sys

if success:
    print(" => All tests passed")
    sys.exit(0)
else:
    print(" => Some failed tests")
    sys.exit(1)
